import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import random
import numpy as np
import copy
from seqeval.metrics import f1_score, accuracy_score
from transformers import AutoTokenizer

from src.utils_data import *
from src.utils_others import *
from src.trainer import BaseTrainer

# Module-level logger; logging.getLogger() without a name returns the root logger.
logger = logging.getLogger()
# Run configuration, read once at import time.
# NOTE(review): get_params is presumably provided by one of the star imports above - confirm.
params = get_params()
# A HuggingFace tokenizer is only loaded when the backbone name contains 'bert';
# for any other backbone auto_tokenizer is left as None.
if 'bert' in params.backbone:
    auto_tokenizer = AutoTokenizer.from_pretrained(params.backbone)
else:
    auto_tokenizer = None
# Label id that the NER evaluation code below skips when collecting labels;
# it is the default ignore_index of nn.CrossEntropyLoss.
pad_token_label_id = nn.CrossEntropyLoss().ignore_index

# ============================================= Probing and Tracking =================================================
def probe_model(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str) -> dict:
    '''
        Probing the model in the trainer on all seen tasks

        A deep copy of the trainer's model gets its classifier re-initialised and
        re-trained with SGD on the accumulated training data of all seen tasks
        (only the classifier parameters are handed to the optimizer). The copy is
        then evaluated on all seen tasks. The trainer's own model and optimizer
        are temporarily swapped out and are always put back, even if probing or
        evaluation raises.

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id
            - phase: 'train','dev'or'test'

        Return:
            - result_dict: the output of evaluate_all_seen_task_{ner,tc,ic}

        Raise:
            - NotImplementedError: if trainer.params.task_name is not 'NER', 'TC' or 'IC'

    '''
    evaluators = {
        'NER': evaluate_all_seen_task_ner,
        'TC': evaluate_all_seen_task_tc,
        'IC': evaluate_all_seen_task_ic,
    }
    task_name = trainer.params.task_name
    # fail fast: do not pay for the deep copy and the probing epochs
    # when the final evaluation could not run anyway
    if task_name not in evaluators:
        raise NotImplementedError()

    tmp_model = copy.deepcopy(trainer.model)
    tmp_model.encoder.eval()
    tmp_model.classifier.reset_parameters()
    tg_params = [{'params': tmp_model.classifier.parameters(), 'lr': 0.1, 'weight_decay': 0.}]
    tmp_optimizer = torch.optim.SGD(tg_params)

    if trainer.params.backbone == 'resnet18':
        probing_epoch = 15
    elif task_name == 'IC':
        probing_epoch = 1
    else:
        probing_epoch = 5

    cur_model, cur_optimizer = trainer.model, trainer.optimizer
    trainer.model, trainer.optimizer = tmp_model, tmp_optimizer
    try:
        accum_train_loader = cl_dataset.get_accum_data_loader(task_id, 'train')

        for _epoch in range(probing_epoch):
            for _idx, X, y in accum_train_loader:
                X, y = X.cuda(), y.cuda()
                logits = trainer.model.forward(X)
                total_loss, _, _ = trainer.model.batch_loss(logits, y)
                trainer.optimizer.zero_grad()
                total_loss.backward()
                trainer.optimizer.step()

        test_result = evaluators[task_name](trainer, cl_dataset, task_id, phase)
    finally:
        # the trainer must never be left holding the throw-away probe
        trainer.model = cur_model
        trainer.optimizer = cur_optimizer

    return test_result

def tracking_model(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int) -> tuple:
    '''
        Compute the class center for linear classifier and Transformer encoder

        The classifier centers are the L2-normalised rows of the classifier
        weights. For task_id==0 they are read from classifier.weight, for later
        tasks from classifier.fc0.weight and classifier.fc1.weight (each
        normalised separately, then concatenated). For NER the first class
        (the O class) is excluded from both kinds of centers.

        Params:
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id (must be >= 0)

        Return:
            - cls_center: the class center of classifier
            - encoder_center: the class center of encoder

        Raise:
            - ValueError: if task_id is negative
            - NotImplementedError: if trainer.params.task_name is not 'NER', 'IC' or 'TC'
    '''
    if task_id < 0:
        # previously this fell through both branches and surfaced as an
        # UnboundLocalError on the return statement
        raise ValueError('task_id must be non-negative, got %d' % task_id)

    num_class = cl_dataset.ACCUM_NUM_CLASS[task_id]
    accum_train_loader = cl_dataset.get_accum_data_loader(task_id, 'train')

    task_name = trainer.params.task_name
    if task_name == 'NER':
        first_class = 1  # remove O class
    elif task_name in ['IC', 'TC']:
        first_class = 0
    else:
        raise NotImplementedError()

    encoder_center = compute_class_feature_center(
        accum_train_loader, trainer.model, list(range(first_class, num_class)), True)

    def _unit_rows(weight):
        # detached CPU copy of the rows, each scaled to unit L2 norm
        return F.normalize(weight.clone().detach().cpu(), p=2, dim=1)

    classifier = trainer.model.classifier
    if task_id == 0:
        cls_center = _unit_rows(classifier.weight.data[first_class:])
    else:
        # only fc0 is sliced with first_class; fc1 is used in full
        cls_center_0 = _unit_rows(classifier.fc0.weight.data[first_class:])
        cls_center_1 = _unit_rows(classifier.fc1.weight.data)
        cls_center = torch.cat((cls_center_0, cls_center_1), dim=0)

    return cls_center, encoder_center
# ==================================================================================================================


# ============================================= For Evaluating NER =================================================
def compute_random_result_ner(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str) -> dict:
    '''
        Score a random-guess baseline (f1) on every task seen so far

        For each seen task the gold labels of the chosen split are collected
        (positions labelled with pad_token_label_id are dropped) and every
        remaining position receives a prediction drawn uniformly from the set
        of gold labels that occur in that task.

        Params: 
            - trainer: the model (not used by this function)
            - cl_dataset: the continual dataset
            - task_id: the task id of the last seen task
            - phase: 'train','dev'or'test'

        Return:
            - random_result: {
                    'micro_f1': one entry per seen task,
                    'macro_f1': one entry per seen task,
                }
    '''
    assert phase in ['train','test','dev']

    micro_scores, macro_scores = [], []

    for seen_id in range(task_id+1):

        with torch.no_grad():
            label_chunks = [
                batch_y.flatten().detach().cpu()
                for _, _, batch_y in cl_dataset.data_loader[phase][seen_id]
            ]
        flat_labels = torch.cat(label_chunks)

        # translate label ids into label strings, skipping ignored positions
        gold_line = [
            cl_dataset.LABEL_LIST[label_id]
            for label_id in map(int, flat_labels.tolist())
            if label_id != pad_token_label_id
        ]
        candidates = list(set(gold_line))
        pred_line = [random.choice(candidates) for _ in range(len(gold_line))]

        micro, macro = evaluate_ner(gold_line, pred_line)
        micro_scores.append(micro)
        macro_scores.append(macro)

    random_result = {}
    random_result['micro_f1'] = micro_scores
    random_result['macro_f1'] = macro_scores
    return random_result

def evaluate_all_seen_task_ner(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str, is_mbpa: bool=False) -> dict:
    '''
        Run the NER evaluation on every task from 0 up to and including task_id

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id of the last seen task
            - phase: 'train','dev'or'test'
            - is_mbpa: if using test time adaptation

        Return:
            - result_dict: per-task entries plus the mean over tasks, e.g. {
                    'Result_test_mif1_0': 90,
                    'Result_test_maf1_0': 80,
                    'Result_test_classf1_0': {...},
                    'Result_test_mif1_1': 95,
                    'Result_test_maf1_1': 85,
                    'Result_test_classf1_1': {...},
                    'Result_test_mean_mif1': 92.5,
                    'Result_test_mean_maf1': 82.5,
                }

    '''
    assert phase in ['train','test','dev']

    result_dict = {}
    micro_scores, macro_scores = [], []

    for seen_id in range(task_id+1):
        micro, macro, per_class = evaluate_current_task_ner(trainer, cl_dataset, seen_id, phase, is_mbpa)

        result_dict.update({
            'Result_%s_mif1_%d'%(phase,seen_id): micro,
            'Result_%s_maf1_%d'%(phase,seen_id): macro,
            'Result_%s_classf1_%d'%(phase,seen_id): per_class,
        })
        micro_scores.append(micro)
        macro_scores.append(macro)

    # averages over all seen tasks go in last
    result_dict['Result_%s_mean_mif1'%(phase)] = np.mean(micro_scores)
    result_dict['Result_%s_mean_maf1'%(phase)] = np.mean(macro_scores)

    return result_dict

        
def evaluate_current_task_ner(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str, is_mbpa:bool=False) -> tuple:
    '''
        Evaluate the model in the trainer on the current task

        The model is switched to eval mode while predicting and is always put
        back into train mode afterwards, even if prediction raises.

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id
            - phase: 'train','dev'or'test'
            - is_mbpa: if using test time adaptation (only applied when task_id>0)

        Return:
            - mif1: micro f1
            - maf1: macro f1
            - ordered_f1_score_dict: the f1 of each class
    '''

    assert phase in ['train','test','dev']

    data_loader = cl_dataset.data_loader[phase][task_id]

    trainer.model.eval()
    try:
        if is_mbpa and task_id>0:
            y_list, pred_list = trainer.meta_mbpa_predict(data_loader, task_id)
        else:
            y_chunks = []
            pred_chunks = []
            with torch.no_grad():
                for _idx, x, y in data_loader:
                    logits = trainer.model.forward(x.cuda())
                    # argmax on the device: only the predicted ids are copied
                    # to the CPU, not the full logits
                    pred_chunks.append(logits.argmax(-1).flatten().detach().cpu())
                    # the labels are only needed on the CPU
                    y_chunks.append(y.flatten().detach().cpu())

            # plain python ints are much cheaper to iterate than 0-d tensors
            y_list = torch.cat(y_chunks).tolist()
            pred_list = torch.cat(pred_chunks).tolist()
    finally:
        trainer.model.train()

    ### calculate f1 score
    label_list = cl_dataset.LABEL_LIST
    pred_line = []
    gold_line = []
    for pred_index, gold_index in zip(pred_list, y_list):
        gold_index = int(gold_index)
        if gold_index == pad_token_label_id:
            continue
        pred_line.append(label_list[int(pred_index)])
        gold_line.append(label_list[gold_index])

    mif1, maf1, ordered_f1_score_dict = evaluate_ner(gold_line,pred_line,cl_dataset.CUR_ENTITY[task_id])

    return mif1, maf1, ordered_f1_score_dict

def evaluate_ner(gold_line,pred_line,entity_list=None):
    '''
        Compute micro f1, macro f1 and (optionally) the f1 of each class for NER

        The predictions are first passed through align_label_set. All scores
        are percentages (seqeval f1 multiplied by 100).

        Params:
            - gold_line: the ground truth labels
            - pred_line: the predicted labels
            - entity_list: the entity in the gold label set (default=None)

        Return:
            - mif1: micro f1
            - maf1: macro f1
            - ordered_f1_score_dict: the f1 of each class, rounded to 2 digits and
              keyed in the order of entity_list (only returned if entity_list is not None)
    '''
    gold_line, pred_line = align_label_set(gold_line, pred_line)

    # seqeval works on a list of sequences; everything is one long sequence here
    gold_batch, pred_batch = [gold_line], [pred_line]

    # micro f1 is the seqeval default, macro f1 averages the per-class f1
    mif1 = f1_score(gold_batch, pred_batch)*100
    maf1 = f1_score(gold_batch, pred_batch, average='macro')*100

    if entity_list is None:
        return mif1, maf1

    per_class_scores = list(np.array(f1_score(gold_batch, pred_batch, average=None))*100)

    # the per-class scores are paired with the alphabetically sorted entity names
    alphabetical = sorted(entity_list)
    assert len(alphabetical)==len(per_class_scores)
    score_by_entity = {
        name: round(score,2) for name, score in zip(alphabetical, per_class_scores)
    }

    # hand the scores back in the caller's entity order
    ordered_f1_score_dict = {name: score_by_entity[name] for name in entity_list}
    return mif1, maf1, ordered_f1_score_dict

def align_label_set(gold_line, pred_line):
    ''' 
        Ensure that the predicted label set is a subset of the gold label set.
        Every predicted label that never occurs in gold_line is replaced by 'O';
        this does not change the precision or recall of the classes in the gold
        label set.

        Note that pred_line is modified in place (and also returned).

        Params:
            - gold_line: a list for ground-truth label
            - pred_line: a list for predicted label

        Return:
            - gold_line: a list for ground-truth label (unchanged)
            - pred_line: a list for predicted label (same object, aligned)
    '''
    # sets give O(1) membership tests inside the loop below
    unseen_labels = set(pred_line) - set(gold_line)
    if unseen_labels:
        # map the predicted labels which are not seen in gold label set to 'O'
        for i, pred in enumerate(pred_line):
            if pred in unseen_labels:
                pred_line[i] = 'O'

    return gold_line, pred_line

# ==================================================================================================================


# ============================================= For Evaluating TC =================================================
def compute_random_result_tc(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str) -> dict:
    '''
        Compute the acc of random guess on all seen tasks

        For every seen task each sample receives a prediction drawn uniformly
        from the classes of that same task (cl_dataset.CUR_CLASS[t_id]).

        Params: 
            - trainer: the model (not used by this function)
            - cl_dataset: the continual dataset
            - task_id: the task id of the last seen task
            - phase: 'train','dev'or'test'

        Return:
            - random_result: {
                    'acc': one entry per seen task,
                }
    '''
    assert phase in ['train','test','dev']

    acc_list = []

    for t_id in range(task_id+1):

        y_list = [y.detach().cpu() for _idx, _x, y in cl_dataset.data_loader[phase][t_id]]
        gold_line = list(torch.cat(y_list).numpy())

        # guess among the classes of the task being scored (t_id), not among
        # the classes of the latest task (task_id)
        candidate_classes = cl_dataset.CUR_CLASS[t_id]
        pred_line = [random.choice(candidate_classes) for _ in gold_line]

        acc_list.append(evaluate_tc(gold_line, pred_line))

    return {'acc': acc_list}

def evaluate_all_seen_task_tc(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str, is_mbpa: bool=False) -> dict:
    '''
        Run the TC evaluation on every task from 0 up to and including task_id

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id of the last seen task
            - phase: 'train','dev'or'test'
            - is_mbpa: if using test time adaptation

        Return:
            - result_dict: per-task entries plus the mean over tasks, e.g. {
                    'Result_test_acc_0': 90,
                    'Result_test_classacc_0': {...},
                    'Result_test_acc_1': 95,
                    'Result_test_classacc_1': {...},
                    'Result_test_mean_acc': 92.5,
                }

    '''
    assert phase in ['train','test','dev']

    result_dict = {}
    task_accs = []

    for seen_id in range(task_id+1):
        task_acc, per_class_acc = evaluate_current_task_tc(trainer, cl_dataset, seen_id, phase, is_mbpa)

        result_dict.update({
            'Result_%s_acc_%d'%(phase,seen_id): task_acc,
            'Result_%s_classacc_%d'%(phase,seen_id): per_class_acc,
        })
        task_accs.append(task_acc)

    # average over all seen tasks goes in last
    result_dict['Result_%s_mean_acc'%(phase)] = np.mean(task_accs)

    return result_dict

        
def evaluate_current_task_tc(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str, is_mbpa: bool=False) -> tuple:
    '''
        Evaluate the model in the trainer on the current task

        The model is switched to eval mode while predicting and is always put
        back into train mode afterwards, even if prediction raises.

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id
            - phase: 'train','dev'or'test'
            - is_mbpa: if using test time adaptation (only applied when task_id>0)

        Return:
            - acc: accuracy
            - ordered_acc_score_dict: the acc of each class
    '''

    assert phase in ['train','test','dev']

    data_loader = cl_dataset.data_loader[phase][task_id]

    trainer.model.eval()
    try:
        if is_mbpa and task_id>0:
            gold_line, pred_line = trainer.meta_mbpa_predict(data_loader, task_id)
        else:
            y_list = []
            pred_list = []
            with torch.no_grad():
                for _idx, x, y in data_loader:
                    logits = trainer.model.forward(x.cuda())
                    pred_list.append(logits.argmax(-1).detach().cpu())
                    # the labels are only needed on the CPU
                    y_list.append(y.detach().cpu())
            gold_line = torch.cat(y_list)
            pred_line = torch.cat(pred_list)
    finally:
        trainer.model.train()

    acc, ordered_acc_score_dict = evaluate_tc(gold_line,pred_line,cl_dataset.CUR_CLASS[task_id])

    return acc, ordered_acc_score_dict

def evaluate_tc(gold_line,pred_line,class_list=None):
    '''
        Compute the accuracy and (optionally) the accuracy of each class for TC

        All scores are percentages. A class without any gold sample gets 0.

        Params:
            - gold_line: the ground truth labels
            - pred_line: the predicted labels
            - class_list: the class in the gold label set (default=None)

        Return:
            - acc: accuracy
            - ordered_acc_score_dict: the acc of each class, keyed in the order
              of class_list (only returned if class_list is not None)
    '''
    gold_arr = np.array(gold_line)
    pred_arr = np.array(pred_line)
    acc = accuracy_score(gold_arr,pred_arr)*100

    if class_list is None:
        return acc

    ordered_acc_score_dict = dict()
    for cls in class_list:
        is_cls = gold_arr==cls
        support = np.sum(is_cls)
        # correct means: predicted as cls and really cls
        hits = np.sum(np.logical_and(pred_arr==cls,is_cls))
        ordered_acc_score_dict[cls] = 0 if support==0 else hits/support*100

    return acc, ordered_acc_score_dict
# ==================================================================================================================


# ============================================= For Evaluating IC =================================================
def compute_random_result_ic(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str) -> dict:
    '''
        Compute the acc of random guess on all seen tasks

        For every seen task each sample receives a prediction drawn uniformly
        from the classes of that same task (cl_dataset.CUR_CLASS[t_id]).

        Params: 
            - trainer: the model (not used by this function)
            - cl_dataset: the continual dataset
            - task_id: the task id of the last seen task
            - phase: 'train','dev'or'test'

        Return:
            - random_result: {
                    'acc': one entry per seen task,
                }
    '''
    assert phase in ['train','test','dev']

    acc_list = []

    for t_id in range(task_id+1):

        y_list = [y.detach().cpu() for _idx, _x, y in cl_dataset.data_loader[phase][t_id]]
        gold_line = list(torch.cat(y_list).numpy())

        # guess among the classes of the task being scored (t_id), not among
        # the classes of the latest task (task_id)
        candidate_classes = cl_dataset.CUR_CLASS[t_id]
        pred_line = [random.choice(candidate_classes) for _ in gold_line]

        acc_list.append(evaluate_ic(gold_line, pred_line))

    return {'acc': acc_list}

def evaluate_all_seen_task_ic(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str, is_mbpa: bool=False) -> dict:
    '''
        Run the IC evaluation on every task from 0 up to and including task_id

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id of the last seen task
            - phase: 'train','dev'or'test'
            - is_mbpa: if using test time adaptation

        Return:
            - result_dict: per-task entries plus the mean over tasks, e.g. {
                    'Result_test_acc_0': 90,
                    'Result_test_classacc_0': {...},
                    'Result_test_acc_1': 95,
                    'Result_test_classacc_1': {...},
                    'Result_test_mean_acc': 92.5,
                }

    '''
    assert phase in ['train','test','dev']

    result_dict = {}
    collected_accs = []

    seen_id = 0
    while seen_id <= task_id:
        outcome = evaluate_current_task_ic(trainer, cl_dataset, seen_id, phase, is_mbpa)
        task_acc, per_class_acc = outcome

        result_dict['Result_%s_acc_%d'%(phase,seen_id)] = task_acc
        result_dict['Result_%s_classacc_%d'%(phase,seen_id)] = per_class_acc
        collected_accs.append(task_acc)

        seen_id += 1

    # average over all seen tasks goes in last
    result_dict['Result_%s_mean_acc'%(phase)] = np.mean(collected_accs)

    return result_dict

        
def evaluate_current_task_ic(trainer: BaseTrainer, cl_dataset: Continual_Dataset, task_id: int, phase: str, is_mbpa: bool=False) -> tuple:
    '''
        Evaluate the model in the trainer on the current task

        The model is switched to eval mode while predicting and is always put
        back into train mode afterwards, even if prediction raises.

        Params: 
            - trainer: the model
            - cl_dataset: the continual dataset
            - task_id: the task id
            - phase: 'train','dev'or'test'
            - is_mbpa: if using test time adaptation (only applied when task_id>0)

        Return:
            - acc: accuracy
            - ordered_acc_score_dict: the acc of each class
    '''

    assert phase in ['train','test','dev']

    data_loader = cl_dataset.data_loader[phase][task_id]

    trainer.model.eval()
    try:
        if is_mbpa and task_id>0:
            gold_line, pred_line = trainer.meta_mbpa_predict(data_loader, task_id)
        else:
            y_list = []
            pred_list = []
            with torch.no_grad():
                for _idx, x, y in data_loader:
                    logits = trainer.model.forward(x.cuda())
                    pred_list.append(logits.argmax(-1).detach().cpu())
                    # the labels are only needed on the CPU
                    y_list.append(y.detach().cpu())
            gold_line = torch.cat(y_list)
            pred_line = torch.cat(pred_list)
    finally:
        trainer.model.train()

    acc, ordered_acc_score_dict = evaluate_ic(gold_line,pred_line,cl_dataset.CUR_CLASS[task_id])

    return acc, ordered_acc_score_dict

def evaluate_ic(gold_line,pred_line,class_list=None):
    '''
        Compute the accuracy and (optionally) the accuracy of each class for IC

        All scores are percentages. A class without any gold sample gets 0.

        Params:
            - gold_line: the ground truth labels
            - pred_line: the predicted labels
            - class_list: the class in the gold label set (default=None)

        Return:
            - acc: accuracy
            - ordered_acc_score_dict: the acc of each class, keyed in the order
              of class_list (only returned if class_list is not None)
    '''
    truth = np.array(gold_line)
    guess = np.array(pred_line)
    acc = accuracy_score(truth,guess)*100

    if class_list is None:
        return acc

    ordered_acc_score_dict = dict()
    for label in class_list:
        truth_is_label = truth==label
        n_samples = np.sum(truth_is_label)
        # correct means: predicted as label and really label
        n_correct = np.sum(np.logical_and(guess==label,truth_is_label))
        if n_samples==0:
            ordered_acc_score_dict[label] = 0
        else:
            ordered_acc_score_dict[label] = n_correct/n_samples*100

    return acc, ordered_acc_score_dict
# ==================================================================================================================

# ======================================== For General Incremental Learning ==========================================
def compute_forgetting(result_matrix):
    '''
        Compute the forgetting for continual learning

        For every task except the last one, the forgetting is the best score of
        that task over all CL steps minus its score after the final step; the
        mean over these tasks is returned.

        Params:
            - result_matrix: (NUM_TASK,NUM_TASK); the result for each task in every CL step
              (rows are indexed by CL step, columns by task)

        Returns:
            - fgt: float
    '''
    num_steps = result_matrix.shape[0]
    drops = [
        np.max(result_matrix[:,col]) - result_matrix[-1,col]
        for col in range(num_steps-1)
    ]
    return np.mean(drops)

def compute_backward_transfer(result_matrix):
    '''
        Compute the backward transfer for continual learning

        For every task except the last one, the backward transfer is its score
        after the final step minus its score on the diagonal (step index equal
        to task index); the mean over these tasks is returned.

        Params:
            - result_matrix: (NUM_TASK,NUM_TASK); the result for each task in every CL step
              (rows are indexed by CL step, columns by task)

        Returns:
            - bwt: float
    '''
    num_steps = result_matrix.shape[0]
    gains = [
        result_matrix[-1,task] - result_matrix[task,task]
        for task in range(num_steps-1)
    ]
    return np.mean(gains)

def compute_forward_transfer(result_matrix, random_result):
    '''
        Compute the forward transfer for continual learning

        For every task except the first one, the forward transfer is its score
        at the step just before its own index minus the random-guess score of
        that task; the mean over these tasks is returned.

        Params:
            - result_matrix: (NUM_TASK,NUM_TASK); the result for each task in every CL step
              (rows are indexed by CL step, columns by task)
            - random_result: (NUM_TASK); the result of random guess for each task

        Returns:
            - fwt: float
    '''
    num_steps = result_matrix.shape[0]
    head_starts = [
        result_matrix[task-1,task] - random_result[task]
        for task in range(1,num_steps)
    ]
    return np.mean(head_starts)
# ==================================================================================================================
